import argparse
import json
import os
import random  # For sampling distractor reasons
import re
import sys
from typing import Dict, List, Optional

import openai
import pandas as pd
from openai import AzureOpenAI
from tqdm import tqdm

def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``csv_path``
        (the only required option) and ``annotation_path``.
    """
    # (flag, add_argument keyword arguments), kept in help-output order.
    option_specs = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        (
            "--output_dir",
            {"type": str, "default": "topic_results", "help": "Directory to write JSON results."},
        ),
        (
            "--csv_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the input CSV with columns video_id,story",
            },
        ),
        (
            "--annotation_path",
            {
                "type": str,
                "default": "reaction_annotation.json",
                "help": "Path to reaction_annotation.json for sampling distractor reasons",
            },
        ),
    )

    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset."
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# ---------------------------------------------------------------------------
# Zero-shot system prompt for best-reason selection
# ---------------------------------------------------------------------------
# The prompt deliberately does not state how many candidates are listed: the
# list sent to the model holds the video's own reasons plus sampled
# distractors, so its length is not fixed at five.
SYSTEM_PROMPT = (
    "You will be given the STORY of a video advertisement and a numbered list "
    "of candidate reasons a viewer might give for taking the recommended action. "
    "Choose EXACTLY ONE reason that best explains why a viewer should take the action, and output that reason verbatim. "
    "Do not output any additional text—just the chosen reason."
)

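# Expected CSV columns: video_id, story, and optionally reasons (a JSON list or a
# ';'-separated string), which is used only when the annotation file has no
# entry for the video.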

# Optional "Answer: <reason>" line in the model reply (case-insensitive).
_ANSWER_RE = re.compile(r"(?i)^answer:\s*(.+)$", re.MULTILINE)
# Reply that starts with a list number: "7", "7." or "7) <reason text>".
_NUMBERED_RE = re.compile(r"^(\d+)[.)]?\s*(.*)$", re.DOTALL)

_MODEL_NAME = "gpt-4o"
# The model is asked to echo a full reason verbatim, so the completion budget
# must fit a whole sentence; a budget of 20 tokens truncated most answers.
_MAX_COMPLETION_TOKENS = 256
# Upper bound on the number of distractor reasons mixed into the candidates.
_NUM_DISTRACTORS = 25


def _cell_to_str(value) -> str:
    """Return a CSV cell as text, mapping missing values (None/NaN) to ''."""
    # pandas represents empty cells as float NaN; NaN is the only value != itself.
    if value is None or (isinstance(value, float) and value != value):
        return ""
    return str(value)


def _clean_reason_list(value) -> List:
    """Normalise a reasons value to a list with falsy entries removed."""
    if isinstance(value, str):
        items = [value]
    elif isinstance(value, (list, tuple)):
        items = list(value)
    else:
        # None, NaN from an empty CSV cell, or any other unsupported shape.
        items = []
    return [item for item in items if item]


def _parse_csv_reasons(raw) -> List:
    """Parse the CSV ``reasons`` cell: a JSON list/string or ';'-separated text."""
    if not isinstance(raw, str):
        return _clean_reason_list(raw)
    try:
        parsed = json.loads(raw)
    except ValueError:
        parsed = None
    if not isinstance(parsed, (str, list)):
        # Not JSON (or JSON of an unusable type): treat as ';'-separated text.
        parsed = [part.strip() for part in raw.split(";")]
    return _clean_reason_list(parsed)


def _extract_chosen_reason(raw_resp: str, candidate_reasons: List) -> str:
    """Turn the raw model reply into the chosen reason.

    Strips an optional ``Answer:`` prefix, and maps a bare list number (or a
    number followed by that candidate's exact text) back to the candidate.
    """
    answer = _ANSWER_RE.search(raw_resp)
    chosen = answer.group(1).strip() if answer else raw_resp.strip()

    numbered = _NUMBERED_RE.match(chosen)
    if numbered:
        position = int(numbered.group(1))
        remainder = numbered.group(2).strip()
        if 1 <= position <= len(candidate_reasons):
            candidate = candidate_reasons[position - 1]
            # Only trust the number when nothing follows it or when what
            # follows is exactly that candidate; otherwise keep the reply.
            if not remainder or remainder == str(candidate).strip():
                return candidate
    return chosen


def _save_results(results: List[dict], output_path: str) -> None:
    """Write results to ``output_path`` via a temp file and atomic rename."""
    # A crash mid-write must not leave a truncated JSON file behind.
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, output_path)


def main():
    """Run best-reason selection over one slice of the CSV and save the results.

    For every record in the requested slice, the video's correct reasons are
    mixed with randomly sampled distractor reasons, the shuffled list is sent
    to the chat model together with the story, and the predicted reason is
    stored. Results are re-saved after every record so a partial run is kept.

    Exits with status 1 if the annotation JSON or the CSV cannot be loaded.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # Setup Azure OpenAI client
    if "OPENAI_API_KEY" not in os.environ or "AZURE_OPENAI_ENDPOINT" not in os.environ:
        print(
            "Warning: OPENAI_API_KEY / AZURE_OPENAI_ENDPOINT not set; "
            "falling back to placeholder credentials, API calls will fail.",
            file=sys.stderr,
        )
    client = AzureOpenAI(
        api_key=os.getenv("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY"),
        api_version="2024-02-15-preview",
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT", "https://your-azure-openai-endpoint/"),
    )

    # -------------------------------------------------------------------
    # Load full reaction annotation to draw distractor reasons
    # -------------------------------------------------------------------
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except Exception as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)
    if not isinstance(annotation_data, dict):
        print(f"Error reading annotation JSON {args.annotation_path}: expected a JSON object keyed by video id")
        sys.exit(1)

    # Flatten all reasons into a single pool we can sample from. Duplicates are
    # dropped (order preserved) so a distractor cannot appear twice in one list.
    all_reasons_pool = list(dict.fromkeys(
        reason
        for reasons in annotation_data.values()
        for reason in _clean_reason_list(reasons)
        if isinstance(reason, str)
    ))

    # Load CSV data
    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (a negative start would wrap around in slicing)
    start_idx = max(args.start, 0)
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        video_id = _cell_to_str(rec.get('video_id', '')).strip()
        try:
            # Collapse every whitespace run (including newlines / form feeds).
            cleaned_text = ' '.join(_cell_to_str(rec.get('story', '')).split())

            # Correct reasons: annotation file first, CSV column as fallback.
            correct_reasons = _clean_reason_list(annotation_data.get(video_id))
            if not correct_reasons:
                correct_reasons = _parse_csv_reasons(rec.get('reasons', ''))

            if not correct_reasons:
                print(f"No reasons for id {video_id}; skipping")
                continue

            # Build candidate list: correct reasons + random distractors
            distractor_pool = [r for r in all_reasons_pool if r not in correct_reasons]
            num_distractors = min(_NUM_DISTRACTORS, len(distractor_pool))
            distractor_reasons = random.sample(distractor_pool, num_distractors)

            candidate_reasons = correct_reasons + distractor_reasons
            random.shuffle(candidate_reasons)

            # Build prompt with candidate reasons list
            reasons_block = "\n".join(f"{i+1}. {r}" for i, r in enumerate(candidate_reasons))
            user_content = (
                f"Story:\n{cleaned_text}\n\nList of reasons:\n{reasons_block}\n\n"
                "Return exactly one line:\nAnswer: <reason>"
            )

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ]

            try:
                response = client.chat.completions.create(
                    model=_MODEL_NAME,
                    messages=messages,
                    max_tokens=_MAX_COMPLETION_TOKENS,
                    temperature=0.0,
                    n=1,
                )
                content = response.choices[0].message.content
                if content is None:
                    # e.g. the reply was filtered; record it as an API error.
                    raise ValueError("model returned no content")
                chosen_reason = _extract_chosen_reason(content.strip(), candidate_reasons)
            except Exception as e:
                print(f"Error during OpenAI call for key {video_id}: {e}")
                chosen_reason = "error_api"

            # Store results
            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'predicted_reason': chosen_reason,
                'candidate_reasons': candidate_reasons,
                'correct_reasons': correct_reasons,
            })

            # Incremental save
            _save_results(results, output_path)

        except Exception as e:
            # Per-record boundary: one bad record must not abort the whole slice.
            print(f"Error processing key {video_id}: {e}")
            continue

    if results:
        print(f"Finished processing. Results saved to {output_path}")
    else:
        print("Finished processing. No results were produced; nothing was written.")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




